perf(vllm): fuse MiniMax M3 BF16 EP experts on MI300X #1782
Thanks for the contribution! For vLLM & SGLang, please ensure that your recipe is similar to the official vLLM recipes and/or the SGLang cookbook. If it is not, please create a PR first before we can merge your single-node PR into the master branch. Let's ensure that the documentation is first class so that the entire ML community can benefit from your hard work! Thank you.
PR authors are responsible for ensuring that after merging, all GitHub Action jobs fully pass. Failures are often just flakes, and simply re-running the failed jobs will fix them. If failed jobs are re-run, PR authors are responsible for ensuring they pass. See GitHub's docs on re-running failed jobs: https://docs.github.com/en/actions/how-tos/manage-workflow-runs/re-run-workflows-and-jobs#re-running-failed-jobs-in-a-workflow
As a rule of thumb, PR authors should request a review and get a PR approval from the respective companies' CODEOWNERS before requesting a review from core maintainers. If additional help is needed, PR authors can reach out to core maintainers over Slack.
Claude finished @Oseltamivir's task in 7m 0s
Review of PR #1782
LGTM: no blocking issues found.
Summary: This is a well-structured incremental optimization for MiniMax M3 EP8 on MI300X. The patch adds two complementary paths (short-context native/BF16 dispatch via the oracle, long-context sparse BF16 emulation with fused SwiGLU), and the gating logic correctly ensures each path is only reached for the profiled configuration.
Key observations:
Summary
long-context shape on MI300X
experts instead of 128 global experts
expert instead of the existing 64-row tile
activation kernel and the 2x-intermediate GEMM1 output
This PR is stacked on #1753 and contains only the incremental EP8 optimization.
It does not include the profiling branch, AITER allreduce/RMSNorm work,
temporary benchmark configuration, or
perf-changelog.yaml changes.
Profile basis
The six-point MI300X profile found expert GEMM1+GEMM2 at 30.31 ms for 1k/c256
and 28.10 ms for 8k/c256. After collective fusion, expert GEMMs remained the
largest classified 8k/c256 phase at 28.79 ms across 114 calls.
At c256, MiniMax M3 has about 216 active tokens and top-k 4, or 864 routed rows
globally. EP8 owns 16 of 128 experts per rank, leaving about 108 local rows,
roughly 6.75 rows per local expert. The existing BF16 config uses a 64-row M
tile, so it can execute about 1,024 padded rows per rank for roughly 108 useful
rows. Global alignment also creates blocks for remote experts that do no useful
GEMM work.
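The row arithmetic above can be reproduced with a short calculation. This is a minimal sketch, assuming routes are spread uniformly across ranks and experts and that every local expert receives at least one row; `padded_rows` is a hypothetical helper, not a function from the patch.

```python
import math


def padded_rows(rows_per_expert: float, num_experts: int, block_m: int) -> int:
    """Rows executed when each expert's rows are padded up to a multiple of block_m.

    Assumes every expert has at least one routed row (hypothetical helper).
    """
    blocks_per_expert = max(1, math.ceil(rows_per_expert / block_m))
    return num_experts * blocks_per_expert * block_m


# Figures quoted in the profile basis.
active_tokens = 216
top_k = 4
ep_size = 8
global_experts = 128

routed_rows = active_tokens * top_k            # routed rows across all ranks
local_experts = global_experts // ep_size      # experts owned by one rank
local_rows = routed_rows // ep_size            # useful rows on one rank
rows_per_expert = local_rows / local_experts   # average rows per local expert

rows_tile64 = padded_rows(rows_per_expert, local_experts, 64)
rows_tile16 = padded_rows(rows_per_expert, local_experts, 16)

print(routed_rows, local_rows, rows_per_expert, rows_tile64, rows_tile16)
```

Under these assumptions the 64-row tile executes 1,024 padded rows per rank for 108 useful rows, while a 16-row tile would execute 256.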
Profile report:
https://github.com/SemiAnalysisAI/InferenceX/blob/profiling/experimental/minimax_m3_mi300x_profile.md
First-principles changes
is based on 16 local experts, while the device counter remains authoritative.
BLOCK_SIZE_M=16, matching the observed route density
and reducing padded expert-row computation by up to 4x versus the 64-row
tile.
applies split SwiGLU-OAI before storing. This halves its BF16 output traffic
and removes a separate activation launch.
It avoids direct atomic accumulation, which the profile identified as a poor
fit for the c256 top-k-4 shape.
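The effect of aligning against local experts instead of global experts can be sketched in NumPy. This is an illustration under stated assumptions, not the alignment kernel used by vLLM or by this patch: `aligned_rows`, the random routing, and the rank-0 `expert_map` layout (global experts 0..15 owned locally, all others marked -1) are hypothetical.

```python
import numpy as np


def aligned_rows(topk_ids, expert_map, num_experts, block_m):
    """Return (padded_rows, useful_rows) after grouping routes by expert.

    expert_map[g] is the local index of global expert g, or -1 when the
    expert lives on another rank (hypothetical layout for illustration).
    """
    local_ids = expert_map[topk_ids.ravel()]
    local_ids = local_ids[local_ids >= 0]             # drop remote routes
    counts = np.bincount(local_ids, minlength=num_experts)
    padded = -(-counts // block_m) * block_m          # ceil to a block_m multiple
    return int(padded.sum()), int(counts.sum())


rng = np.random.default_rng(0)
num_tokens, top_k, global_experts, local_experts = 216, 4, 128, 16

# Each token picks top_k distinct experts uniformly at random (assumption).
topk_ids = np.argsort(rng.random((num_tokens, global_experts)), axis=1)[:, :top_k]

# Global alignment: every expert gets blocks, 64-row tiles.
identity_map = np.arange(global_experts)
padded_global, useful_global = aligned_rows(topk_ids, identity_map, global_experts, 64)

# Local alignment on rank 0: only the 16 owned experts, 16-row tiles.
rank0_map = np.full(global_experts, -1)
rank0_map[:local_experts] = np.arange(local_experts)
padded_local, useful_local = aligned_rows(topk_ids, rank0_map, local_experts, 16)

print(padded_global, useful_global, padded_local, useful_local)
```

The printed pairs show how many padded rows each scheme allocates against how many rows are actually routed; the exact values depend on the random routing.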
The path is gated to the exact gfx94x MiniMax M3 EP8 BF16 shape. gfx95x and
other models/configurations are unchanged.
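The GEMM1 fusion described in this section can be illustrated with a small NumPy comparison of how many intermediate elements each variant stores. This sketch assumes the gpt-oss style SwiGLU definition (clamp limit 7.0, alpha 1.702) with gate and up projections held in separate halves; the PR description does not state the constants or layout used by the patch, and all function names here are hypothetical.

```python
import numpy as np

ALPHA, LIMIT = 1.702, 7.0  # assumed constants, not taken from the patch


def swiglu_oai_split(gate, up):
    """Gated activation on separate gate/up halves (assumed definition)."""
    gate = np.minimum(gate, LIMIT)
    up = np.clip(up, -LIMIT, LIMIT)
    return gate / (1.0 + np.exp(-ALPHA * gate)) * (up + 1.0)


def gemm1_unfused(x, w_gate, w_up):
    """Store the 2x-intermediate GEMM1 output, then activate in a second pass."""
    inter = w_gate.shape[1]
    h = x @ np.concatenate([w_gate, w_up], axis=1)    # [rows, 2 * inter] stored
    return swiglu_oai_split(h[:, :inter], h[:, inter:]), h.size


def gemm1_fused(x, w_gate, w_up):
    """Apply the activation before storing, so only [rows, inter] is written."""
    out = swiglu_oai_split(x @ w_gate, x @ w_up)
    return out, out.size


rng = np.random.default_rng(0)
x = rng.standard_normal((16, 32))
w_gate = rng.standard_normal((32, 64)) * 0.1
w_up = rng.standard_normal((32, 64)) * 0.1

ref, stored_unfused = gemm1_unfused(x, w_gate, w_up)
out, stored_fused = gemm1_fused(x, w_gate, w_up)
print(stored_unfused, stored_fused, np.allclose(ref, out))
```

Both variants compute the same values; the fused form writes half as many intermediate elements because the 2x-wide GEMM1 result never reaches memory.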
Validation
Static and local validation:
python -m pytest utils/matrix_logic/ -q: 156 passed
bash -n benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh, compileall, and git diff --check
expert-map reduction correctness tests
MI300X serving validation is pending infrastructure recovery. The exact six-job
matrix (c1/c16/c256 for 1k1k and 8k1k) was dispatched four times, but every
attempt failed before GPU allocation because the Slurm controller was
unreachable:
https://github.com/SemiAnalysisAI/InferenceX/actions/runs/27569397626
Note
Medium Risk
Touches inference hot-path MoE kernels and backend selection with fused numerics and atomic top-k reduction, though scope is tightly gated to a specific gfx94x MiniMax-M3 EP8 shape.
Overview
Adds a second runtime patch (minimaxm3_mi300x_ep_mxfp8.patch) to the MI300X MiniMax-M3 benchmark recipe and wires it in after the base MXFP8 patch, with oracle-marker checks so EP8 optimizations are applied idempotently.
EP8 backend split by context length: For profiled gfx94x MiniMax-M3 EP8, short context (max_model_len ≤ 4096) now selects the native MXFP8 backend with the existing mixed native/BF16 expert policy extended to EP8; long context keeps emulation but switches to a new sparse local-route BF16 path instead of treating all EP as slow emulation-only.
Long-context sparse BF16 path (emulation): When the exact EP8 shape matches, MoE runs only locally owned routes (expert alignment can use num_local_experts so buffers and padding are sized for ~16 local experts rather than 128 global) and uses 16-row grouped GEMM tiles. GEMM1 is fused with split SwiGLU-OAI via a new fused_moe_gated_kernel; GEMM2 still applies router weights in the expert GEMM, and the fused top-k reduction can skip re-multiplying weights (apply_weights=False).
Native MXFP8 EP improvements: Grouped GEMM launch bounds use a local-expert-aware _max_post_padded; GEMM2 can fuse top-k reduction with relaxed atomics; SwiGLU+MXFP8 quant gains a route-aware variant that only processes aligned local rows after GEMM1.
Changes are gated to the profiled MiniMax-M3 EP8 configuration on gfx94x; other models and platforms are intended to be unchanged.
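The gating and context-length split summarized in this overview could look roughly like the following. This is a hedged sketch only: `MoEConfig`, `select_backend`, the model string, and the returned labels are hypothetical names chosen for illustration and are not taken from the patch or from vLLM.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MoEConfig:
    """Hypothetical container for the fields the gate inspects."""
    gfx_arch: str
    model: str
    ep_size: int
    max_model_len: int


def select_backend(cfg: MoEConfig) -> str:
    """Pick a backend label for the profiled shape; leave everything else alone."""
    profiled = (
        cfg.gfx_arch.startswith("gfx94")
        and cfg.model == "MiniMax-M3"
        and cfg.ep_size == 8
    )
    if not profiled:
        return "unchanged"
    # Short context uses the native MXFP8 backend; long context keeps
    # emulation with the sparse local-route BF16 path.
    if cfg.max_model_len <= 4096:
        return "native_mxfp8"
    return "emulation_sparse_bf16"


short_ctx = select_backend(MoEConfig("gfx942", "MiniMax-M3", 8, 4096))
long_ctx = select_backend(MoEConfig("gfx942", "MiniMax-M3", 8, 16384))
other_arch = select_backend(MoEConfig("gfx950", "MiniMax-M3", 8, 4096))
print(short_ctx, long_ctx, other_arch)
```

Keeping the predicate in one place makes it easier to confirm that gfx95x and other models fall through to the unchanged path.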
Reviewed by Cursor Bugbot for commit 16c596a.